In [31]:
%matplotlib inline
from sklearn.datasets import fetch_mldata
import matplotlib.pyplot as plt
import numpy as np
from IPython.html.widgets import interact
from IPython.display import display
import core
In [32]:
# Load the "MNIST original" dataset through scikit-learn's fetch_mldata,
# using "./MNIST dataset" as the data_home directory.
# NOTE(review): fetch_mldata is a legacy loader that newer scikit-learn
# releases no longer ship — confirm it exists in the installed version before
# relying on Restart & Run All.
mnist = fetch_mldata("MNIST original", data_home="./MNIST dataset")
In [33]:
# Inspect the shape of the label vector; the recorded output is (70000,).
mnist.target.shape
Out[33]:
(70000,)
In [38]:
def show_examples(i):
    """Show sample ``i`` of the MNIST data as an image, followed by its label.

    Parameters
    ----------
    i : int
        Row index into ``mnist.data`` / ``mnist.target``.
    """
    # Each row of mnist.data is reshaped to a 28x28 grid for display.
    plt.matshow(mnist.data[i].reshape((28, 28)), cmap='Greys_r')
    display(mnist.target[i])


# Browse indices 60000..69999 with a slider. The range is given as a
# (min, max) tuple: that is the documented slider abbreviation, whereas a
# list is interpreted as a set of dropdown options by current ipywidgets.
interact(show_examples, i=(60000, 70000 - 1))
6.0
In [57]:
# Build the network from the layer-size list [784, 40, 10] using the project
# helper core.gen_network. 784 matches the 28x28 reshape used when displaying
# samples.
# NOTE(review): presumably 784 inputs, 40 hidden units and 10 outputs —
# confirm against core.gen_network. The recorded output is a list of two
# arrays; the second has 10 rows of 41 values, which looks like 40 weights
# plus one extra (bias?) column per output unit — verify.
phipps = core.gen_network([784,40,10])
# Echo the freshly generated network (a list of NumPy arrays in the recorded output).
phipps
Out[57]:
[array([[-1.18045863, 0.99909873, 1.48408414, ..., 0.07700514,
0.02055552, -1.85437819],
[ 0.82048939, -0.6677628 , 0.09726159, ..., -0.25809884,
-0.81745753, 0.2215101 ],
[-1.82209173, -0.30833717, 0.70276417, ..., 1.26639302,
-1.2415207 , 1.69466223],
...,
[ 0.04177277, 0.87976371, -2.34645315, ..., -0.04584682,
-0.86120687, -0.19235447],
[ 0.37467598, 0.55056671, 0.74541514, ..., -0.18803911,
0.4837592 , -0.55211067],
[-0.30821086, 0.60741 , -0.7220094 , ..., 1.05986457,
-0.75256015, -0.67976349]]),
array([[ 9.75242254e-01, 1.41007556e-02, 5.08885126e-01,
-8.89233039e-01, 6.96220748e-01, 1.24515760e+00,
1.82054426e+00, -7.69981773e-01, 1.16082585e+00,
-7.47697535e-01, -4.78418942e-01, -1.15748825e+00,
2.42877335e+00, 7.82008423e-01, 1.09819442e+00,
-1.62097191e+00, 4.47348051e-01, 1.90487661e-02,
9.31129681e-01, -1.58091681e+00, 2.47950177e-01,
2.64875066e-01, 3.45531524e-01, 8.41898548e-01,
7.25186717e-01, -3.43024341e-01, -8.66679494e-01,
1.09679531e+00, 1.62208156e+00, 7.47115079e-01,
-9.51291590e-02, 6.70098738e-01, 6.12628265e-01,
-3.97879537e-01, -6.58478779e-01, -9.71701002e-01,
2.08484079e+00, 7.34469878e-01, -4.60973104e-01,
-1.45835169e+00, -1.64901872e-02],
[ 6.85883089e-02, -8.04937495e-01, -8.89101505e-02,
6.94017611e-01, -6.35280609e-01, -3.76968710e-01,
5.05434530e-01, 1.80726206e+00, -1.58980649e-01,
1.09974245e+00, -1.32264429e+00, 1.90657656e+00,
-9.33108409e-01, -1.07977908e+00, 1.96674393e-01,
-2.17243315e+00, 9.60888274e-01, 1.20506637e-01,
-7.45032670e-02, -1.19349150e-01, 1.62598708e-01,
3.81075533e-01, 1.34310184e-01, -1.00264873e+00,
-4.00769328e-01, 1.06994380e+00, -1.19756697e+00,
-1.34582355e-01, 3.49737477e-01, -1.24913732e+00,
7.49004377e-01, 1.64415550e-03, -4.47172894e-01,
-1.12367000e+00, -5.11820922e-02, 5.85872158e-01,
1.80032809e-01, -3.44284483e-01, 5.50596298e-01,
8.61070705e-01, -2.64503394e+00],
[ -7.17755040e-01, 1.25873950e+00, -6.15900752e-01,
1.88176084e+00, -1.38957513e-01, -8.32114546e-01,
1.09082489e+00, 3.07008242e-01, -2.49707996e-01,
2.87873512e-01, -1.80971164e+00, -1.62605595e-01,
4.96582846e-01, 4.28194916e-01, 9.20176978e-01,
-7.84112068e-01, 1.02888091e+00, -2.11363380e+00,
2.12364776e+00, -4.37296905e-01, 9.55451706e-01,
-1.77394898e+00, 9.21702520e-01, -7.37541791e-01,
4.52255398e-01, 2.89586567e-01, -6.90722916e-01,
8.70102920e-01, -6.77134006e-01, 4.06319051e-01,
-2.02996231e+00, -6.87213336e-01, -4.19412398e-01,
-7.41710604e-01, -1.44594966e+00, -1.12372177e+00,
-5.78252725e-01, -1.75459687e+00, -1.65693468e+00,
-5.25230809e-01, -5.26290114e-01],
[ -9.53162131e-01, -1.47788376e+00, -5.36040341e-01,
1.90820653e+00, 2.93999947e-01, -4.02311677e-01,
3.08598136e-01, -9.61287383e-01, -1.65883699e+00,
-9.10116003e-01, -1.07168199e+00, -4.07681526e-01,
-2.87004927e-01, -7.38379411e-01, 2.76408429e+00,
-1.17446777e+00, 3.10023319e-01, -3.18178327e-01,
1.07402716e+00, -8.17580380e-01, -1.28361483e-01,
-1.83752294e+00, 6.60678966e-01, -9.61288166e-01,
-3.64238899e-01, 4.01921252e-01, 5.08135350e-01,
-1.06711128e+00, -1.17119249e-01, 2.27852832e+00,
5.60911679e-01, 9.40698356e-01, -9.07170858e-02,
-7.76294600e-01, 1.72900748e+00, 6.02730236e-01,
1.92762248e+00, -4.27226816e-01, -1.66415506e+00,
1.55091924e-01, -4.54765886e-01],
[ 7.20027740e-01, 1.10064750e+00, 2.19248269e-01,
6.73060949e-01, -1.41953696e+00, 6.79567330e-01,
8.70526227e-01, -1.18067024e+00, 1.68951393e+00,
-3.82680568e-01, 2.82609687e-01, 5.91561127e-01,
4.64532361e-01, -6.80904187e-01, -4.90137425e-01,
-6.12901203e-02, 1.75664090e+00, 2.05275775e-01,
1.05732823e+00, 1.65174147e+00, -1.93911139e-02,
7.85811135e-01, -4.86033874e-01, 1.05368517e-01,
1.04481946e+00, 3.25097521e-02, -6.78813295e-01,
7.12082843e-01, 1.88914165e+00, -1.46240649e-02,
3.07209784e-02, -1.29253366e+00, 1.13069790e+00,
-5.10839439e-01, 1.25137804e+00, -5.17807563e-02,
7.07975740e-01, -1.67950981e+00, 4.52969626e-01,
-7.32450809e-01, -5.97715714e-01],
[ -3.07763763e+00, -2.39689385e+00, 2.23819523e+00,
-2.46279487e-01, 5.25017005e-01, -5.11388204e-01,
1.62351389e-01, 1.94517355e-01, -1.81335620e+00,
1.04459390e+00, -2.00430906e-01, 7.82137929e-01,
6.34216747e-01, 5.73420815e-01, -2.08003095e-02,
-1.14274717e+00, -1.63230120e+00, -1.50928028e+00,
1.13730949e+00, -5.86315859e-01, 1.34773743e+00,
2.47043737e+00, -1.12148644e+00, -3.01146849e-02,
-6.39609418e-01, 1.09726277e+00, 8.04369753e-01,
6.68984137e-01, -5.75887384e-01, -4.12289309e-01,
2.47163189e+00, 1.04854420e-01, -1.12389633e+00,
-5.51742943e-01, 5.66546160e-01, 1.20872164e+00,
-1.04900076e+00, -4.01879701e-01, 7.54082067e-01,
-2.11575769e+00, -5.86761088e-01],
[ -1.73181837e-01, 2.32205903e+00, -1.47600729e-01,
6.04833565e-01, 1.07243579e+00, -2.49656365e-02,
-8.95177703e-01, -2.82877057e-01, -9.32747954e-01,
-4.86338696e-01, -9.41220140e-01, 1.50764719e-01,
3.30122407e-01, -4.97690278e-01, 2.08684342e-01,
9.19951980e-01, 7.70242391e-01, -7.67422983e-01,
1.42813798e+00, -5.66349541e-01, -1.12064859e+00,
-6.47902792e-01, -1.35052063e+00, 2.04928533e+00,
7.05055805e-01, -4.95161117e-01, -4.31363714e-01,
-1.14731386e+00, 2.22643001e-01, 1.59124038e+00,
3.88060600e-01, -1.10747728e-01, -6.17296289e-01,
2.84934966e-01, -1.84627412e+00, 7.34661641e-01,
2.33893340e-02, 1.38324164e+00, 1.10011516e+00,
1.58857752e+00, -4.93708241e-01],
[ -1.93376335e+00, 5.04884193e-01, 1.05666633e+00,
8.25973188e-01, -1.92519725e+00, 6.84861819e-02,
6.16630412e-01, 5.78388230e-01, 2.76796947e+00,
1.21815848e+00, 5.63377432e-03, -7.31650783e-01,
2.22964051e+00, 1.34501817e-01, 1.92525852e-01,
2.38343821e-02, -7.53133231e-02, -8.78864832e-01,
-2.66117090e-02, 7.64527104e-01, -7.68422175e-01,
-7.42194566e-01, 8.29575546e-01, 1.92687853e-01,
-1.13093905e+00, 7.79662200e-02, -1.59124547e+00,
1.09491924e+00, -9.90646528e-02, 1.91288543e+00,
-2.31966735e-01, -5.50366494e-01, 9.29403579e-01,
3.27988516e-01, 1.41739704e+00, -4.04344941e-02,
2.24447820e-01, -1.00858039e+00, -1.07560443e+00,
2.13582799e-01, -4.20494767e-03],
[ 4.29793534e-01, -4.67509472e-01, 5.63665989e-02,
1.23153460e+00, -1.82049585e-01, 3.75116733e-01,
-7.68187321e-01, -9.99732429e-01, 3.80840822e-01,
-3.31721457e-01, 1.68333789e-01, 5.81615286e-01,
4.32785867e-01, -3.18269734e-01, -1.72406876e+00,
-5.04925733e-01, 5.94999764e-02, 1.17032932e+00,
1.43017119e+00, 7.39077142e-01, -6.97929026e-01,
-1.61578240e-01, 2.05903424e-01, -3.53550538e-01,
1.23904219e+00, -1.30681204e-02, 1.34279534e+00,
1.07004182e+00, 2.27285518e+00, -1.18475655e+00,
-1.38341022e+00, -9.60055939e-02, -8.33514450e-01,
3.74992393e-01, -6.71278615e-01, -1.16839763e-01,
2.89831294e-01, 1.57833465e+00, -3.07592157e-01,
7.33935168e-01, 3.36846543e-01],
[ -1.19832585e-01, -7.61416242e-01, 3.37576099e-01,
5.23897538e-01, 1.52908064e+00, -3.69910056e-01,
-1.31772167e+00, -4.65631741e-01, -4.57807040e-01,
-1.56974155e+00, -1.60390719e+00, -6.57672471e-02,
-2.01546300e+00, -5.93323070e-01, -2.78441156e-01,
6.06972189e-01, -3.48189942e-01, -1.27090146e+00,
3.89334447e-01, -1.31603870e+00, -2.92827181e-01,
3.45542829e-01, 2.57318725e-01, -8.99442022e-01,
6.71602330e-01, 1.37179934e-01, -3.39433169e-01,
2.50900689e-03, 1.31651923e+00, -5.53613980e-03,
1.28153470e+00, -2.37858769e+00, 9.14503297e-01,
3.05983160e-01, 1.82768448e-01, -3.23782584e+00,
8.29916176e-01, 1.22781046e+00, 6.85636267e-02,
-1.59385193e+00, -7.45062918e-01]])]
In [53]:
# Build a shuffled copy of the training samples.
#
# Only indices 0..59999 are used: index 60000 is where the held-out range
# used by show_examples / core.check_net begins, so it must not leak into the
# training set. np.arange also yields integer indices directly, so no int()
# cast is needed when indexing.
train_data_index = np.arange(60000)
np.random.seed(1)  # fixed seed so the shuffle order is reproducible
np.random.shuffle(train_data_index)

# Samples and labels are reordered with the same index sequence so that they
# stay aligned.
# NOTE(review): the training cell passes mnist.data / mnist.target rather
# than these shuffled lists — confirm which is intended.
mist_data = [mnist.data[n] for n in train_data_index]
mist_target = [mnist.target[n] for n in train_data_index]
In [ ]:
%%timeit -r1 -n1
# Train `phipps` with core.train_network and time it; -r1 -n1 restricts
# %%timeit to a single run so the training is executed only once.
# NOTE(review): the meaning of the positional arguments 10, 10000, 15, 0.01
# is defined in core.train_network and is not visible here — document them
# (or pass them by keyword) once confirmed.
# NOTE(review): this passes the full, unshuffled mnist.data / mnist.target,
# not the shuffled mist_data / mist_target built earlier — confirm intent.
# NOTE(review): this cell has no execution count while the evaluation cell
# does (In [64]); if train_network returns a new network instead of updating
# `phipps` in place, the result is also discarded here — verify.
core.train_network(phipps,mnist.data, mnist.target, 10, 10000, 15, 0.01)
In [64]:
# Evaluate `phipps` on the index range [60000, 69999] with core.check_net.
# Recorded output: 9.8009800980098 and 980/9999 (the first number equals
# 100 * 980 / 9999).
# NOTE(review): presumably "percent correct" and "correct/total". The
# denominator is 9999 although indices 60000..69999 span 10000 samples, so
# one endpoint appears to be excluded — verify in core.check_net.
# NOTE(review): ~9.8% is about chance level for 10 classes, which suggests
# the evaluated network was effectively untrained — confirm the training
# cell actually ran and updated `phipps`.
core.check_net(phipps, mnist.data, mnist.target, [60000,70000-1])
9.8009800980098
980/9999
In [52]:
In [ ]:
Content source: millernj/phys202-project
Similar notebooks: